# files needed:
# het_compiled from blocks -> support, interpolate from utilities, multidim from utilities

using LinearAlgebra
import Base: *, reshape, copy, @property
#include("het_compiled.jl")
#include("interpolate.jl")
#include("multidim.jl")

"""
    LawOfMotion

Abstract type representing a matrix that operates on the state space.

Rather than a giant `Ns × Ns` matrix (even a sparse one), some other representation
is almost always desirable; such representations are subtypes of this type. The
subtypes in this file each define `*` (apply to an array) and a `.T` property.
"""
abstract type LawOfMotion end

# No generic `*` (matmul) or `.T` fallback is defined for `LawOfMotion`; each concrete
# subtype below defines its own.

"""
    lottery_1d(a, a_grid, monotonic=false)

Build a `PolicyLottery1D` for the policy `a` on `a_grid`.

The lottery coordinates come from `interpolate_coord` when `monotonic` is true and
from `interpolate_coord_robust` otherwise. Whatever that helper returns is splatted
into the constructor ahead of `a_grid`.
"""
function lottery_1d(a, a_grid, monotonic=false)
    coord_fn = monotonic ? interpolate_coord : interpolate_coord_robust
    return PolicyLottery1D(coord_fn(a_grid, a)..., a_grid)
end

""" note from orig code: their code only operates on the final dimension, should make it more general """
# Direct translation of the original version; TODO: verify that the shapes are stored properly.
# PolicyLottery1D(i, pi, grid, forward=true)
#
# Law of motion for a policy on the one-dimensional endogenous grid `grid`, stored as
# an index array `i` and a weight array `pi` (presumably the coordinates returned by
# `interpolate_coord` / `interpolate_coord_robust` -- confirm against those helpers).
#
# Every leading (non-policy) dimension of `i` and `pi` is collapsed into one, so both
# are stored with shape `flatshape == (:, length(grid))`. `original_shape` keeps
# `size(i)` so results of `*` can be reshaped back; `endog_shape == size(grid)`.
#
# The constructor throws `DimensionMismatch` if `length(i)` is not a multiple of
# `length(grid)`, or if `pi` has a different number of elements than `i`.
mutable struct PolicyLottery1D{TI,TP,TG} <: LawOfMotion
    i::TI                               # indices, reshaped to `flatshape`
    pi::TP                              # weights, reshaped to `flatshape`
    grid::TG                            # endogenous grid
    forward::Bool                       # true -> forward iteration, false -> expectation
    flatshape::Tuple{Int,Int}
    original_shape::Tuple{Vararg{Int}}
    endog_shape::Tuple{Vararg{Int}}
end

function PolicyLottery1D(
    i::AbstractArray, pi::AbstractArray, grid::AbstractArray, forward::Bool=true
)
    original_shape = size(i)
    # Collapse all non-policy dimensions into a single leading one; the policy
    # (grid) dimension is the last one.
    i_flat = reshape(i, :, length(grid))
    flatshape = size(i_flat)
    pi_flat = reshape(pi, flatshape)
    return PolicyLottery1D(
        i_flat, pi_flat, grid, forward, flatshape, original_shape, size(grid),
    )
end

# `p.T` returns a shallow copy (sharing the same arrays) with `forward` flipped.
function Base.getproperty(p::PolicyLottery1D, name::Symbol)
    name === :T || return getfield(p, name)
    return PolicyLottery1D(
        getfield(p, :i), getfield(p, :pi), getfield(p, :grid), !getfield(p, :forward),
        getfield(p, :flatshape), getfield(p, :original_shape), getfield(p, :endog_shape),
    )
end

# `p * X`: reshape `X` to `flatshape`, apply the forward or expectation kernel from
# `het_compiled`, and reshape the result back to `original_shape`.
function Base.:*(p::PolicyLottery1D, X::AbstractArray)
    X_flat = reshape(X, p.flatshape)
    apply = p.forward ? het_compiled.forward_policy_1d : het_compiled.expectation_policy_1d
    return reshape(apply(X_flat, p.i, p.pi), p.original_shape)
end

"""
    ShockedPolicyLottery1D(i, pi, grid, forward=true)

Policy lottery on a one-dimensional endogenous grid whose forward iteration uses
`het_compiled.forward_policy_shock_1d`.

Julia cannot subtype a concrete struct, so this type declares the same fields as
`PolicyLottery1D` itself instead of inheriting them. Leading (non-policy) dimensions
of `i` and `pi` are collapsed into one (`flatshape == (:, length(grid))`), and
`original_shape == size(i)` is used to reshape the results of `*`.

- `p.T` returns a shallow copy with `forward` flipped.
- `p * X` throws `ArgumentError` when `forward` is false, because no expectation
  kernel exists for the shocked lottery.
"""
mutable struct ShockedPolicyLottery1D{TI,TP,TG} <: LawOfMotion
    i::TI
    pi::TP
    grid::TG
    forward::Bool
    flatshape::Tuple{Int,Int}
    original_shape::Tuple{Vararg{Int}}
    endog_shape::Tuple{Vararg{Int}}
end

function ShockedPolicyLottery1D(
    i::AbstractArray, pi::AbstractArray, grid::AbstractArray, forward::Bool=true
)
    original_shape = size(i)
    i_flat = reshape(i, :, length(grid))  # collapse non-policy dims into one
    flatshape = size(i_flat)
    return ShockedPolicyLottery1D(
        i_flat, reshape(pi, flatshape), grid, forward, flatshape, original_shape,
        size(grid),
    )
end

function Base.getproperty(p::ShockedPolicyLottery1D, name::Symbol)
    name === :T || return getfield(p, name)
    return ShockedPolicyLottery1D(
        getfield(p, :i), getfield(p, :pi), getfield(p, :grid), !getfield(p, :forward),
        getfield(p, :flatshape), getfield(p, :original_shape), getfield(p, :endog_shape),
    )
end

function Base.:*(p::ShockedPolicyLottery1D, X::AbstractArray)
    p.forward || throw(ArgumentError("Expectation policy shock not implemented."))
    X_flat = reshape(X, p.flatshape)
    result = het_compiled.forward_policy_shock_1d(X_flat, p.i, p.pi)
    return reshape(result, p.original_shape)
end

"""
    lottery_2d(a, b, a_grid, b_grid; monotonic=false)

Build a `PolicyLottery2D` for the policies `a` (on `a_grid`) and `b` (on `b_grid`).

Coordinates for each policy come from `interpolate_coord` when `monotonic` is true
and from `interpolate_coord_robust` otherwise.
"""
function lottery_2d(a, b, a_grid, b_grid; monotonic=false)
    # NOTE: there are no monotonic 2D examples, so the `monotonic=true` path is
    # expected to be unused.
    coord_fn = monotonic ? interpolate_coord : interpolate_coord_robust
    i_a, pi_a = coord_fn(a_grid, a)
    i_b, pi_b = coord_fn(b_grid, b)
    return PolicyLottery2D(i_a, pi_a, i_b, pi_b, a_grid, b_grid)
end

# (`LawOfMotion` is declared once near the top of this file; no redeclaration needed.)

"""
    PolicyLottery2D(i1, pi1, i2, pi2, grid1, grid2, forward=true)

Law of motion for a pair of policies on the endogenous grids `grid1` and `grid2`.
It is stored as index arrays `i1`, `i2` and weight arrays `pi1`, `pi2`.

All leading (non-policy) dimensions are collapsed into one, so the four arrays are
stored with shape `flatshape == (:, size(grid1)..., size(grid2)...)`. `shape` keeps
`size(i1)` so results of `*` can be reshaped back, and `endog_shape` is the last two
entries of `shape`.

- `p.T` returns the same lottery with `forward` flipped.
- `p * X` applies `forward_policy_2d` when `forward` is true and
  `expectation_policy_2d` otherwise.

Throws `DimensionMismatch` if the arrays cannot be reshaped to `flatshape`.
"""
struct PolicyLottery2D <: LawOfMotion
    i1::Array
    i2::Array
    pi1::Array
    pi2::Array
    grid1::Array
    grid2::Array
    forward::Bool
    flatshape::Tuple
    shape::Tuple
    endog_shape::Tuple

    function PolicyLottery2D(i1, pi1, i2, pi2, grid1, grid2, forward=true)
        # Flatten the non-policy dims into one leading dim; keep the grid dims last.
        i1_flat = reshape(i1, :, size(grid1)..., size(grid2)...)
        flatshape = size(i1_flat)
        orig_shape = size(i1)
        endog_shape = orig_shape[end-1:end]
        return new(
            i1_flat, reshape(i2, flatshape), reshape(pi1, flatshape),
            reshape(pi2, flatshape), grid1, grid2, forward, flatshape, orig_shape,
            endog_shape,
        )
    end
end

function Base.getproperty(p::PolicyLottery2D, sym::Symbol)
    sym === :T || return getfield(p, sym)
    # Reshape back to the original shape (no copy) so that the constructor
    # recomputes `shape` and `flatshape` exactly as for the original lottery.
    shape = getfield(p, :shape)
    return PolicyLottery2D(
        reshape(getfield(p, :i1), shape), reshape(getfield(p, :pi1), shape),
        reshape(getfield(p, :i2), shape), reshape(getfield(p, :pi2), shape),
        getfield(p, :grid1), getfield(p, :grid2), !getfield(p, :forward),
    )
end

function Base.:*(p::PolicyLottery2D, X::Array)
    X_flat = reshape(X, p.flatshape)
    kernel = p.forward ? forward_policy_2d : expectation_policy_2d
    return reshape(kernel(X_flat, p.i1, p.i2, p.pi1, p.pi2), p.shape)
end

"""
    ShockedPolicyLottery2D(i, pi, grid1, grid2; forward=true)

Policy lottery on the endogenous grids `grid1` and `grid2` whose forward iteration
uses `forward_policy_shock_2d`.

Leading (non-policy) dimensions of `i` and `pi` are collapsed into one
(`flatshape == (:, size(grid1)..., size(grid2)...)`). `shape == size(i)` is used to
reshape the results of `*`, and `endog_shape` is the last two entries of `shape`.

- `p.T` returns the same lottery with `forward` flipped.
- `p * X` throws `ArgumentError` when `forward` is false, because no expectation
  kernel exists for the shocked lottery.
"""
mutable struct ShockedPolicyLottery2D <: LawOfMotion
    i::Array
    pi::Array
    grid1::Array
    grid2::Array
    forward::Bool
    flatshape::Tuple
    shape::Tuple
    endog_shape::Tuple

    function ShockedPolicyLottery2D(i, pi, grid1, grid2; forward=true)
        # Flatten the non-policy dims into one leading dim; keep the grid dims last.
        i_flat = reshape(i, :, size(grid1)..., size(grid2)...)
        flatshape = size(i_flat)
        shape = size(i)
        return new(
            i_flat, reshape(pi, flatshape), grid1, grid2, forward, flatshape, shape,
            shape[end-1:end],
        )
    end
end

function Base.getproperty(p::ShockedPolicyLottery2D, sym::Symbol)
    sym === :T || return getfield(p, sym)
    shape = getfield(p, :shape)
    return ShockedPolicyLottery2D(
        reshape(getfield(p, :i), shape), reshape(getfield(p, :pi), shape),
        getfield(p, :grid1), getfield(p, :grid2); forward=!getfield(p, :forward),
    )
end

function Base.:*(p::ShockedPolicyLottery2D, X::Array)
    p.forward || throw(ArgumentError("Expectation policy shock not implemented."))
    X_flat = reshape(X, p.flatshape)
    return reshape(forward_policy_shock_2d(X_flat, p.i, p.pi), p.shape)
end

"""
    Markov(Pi, i)

Law of motion that applies the matrix `Pi` along dimension `i` of the state space.
`m * X` delegates to `multiply_ith_dimension(Pi, i, X)`.

`m.T` returns a new `Markov` with `Pi` transposed. The transpose is materialized
with `permutedims` so it is a plain `Array`.
"""
struct Markov <: LawOfMotion
    Pi::Array
    i::Int
end
# The methods live outside the struct body: function definitions inside a struct
# body count as inner definitions and would suppress the default `Markov(Pi, i)`.

function Base.getproperty(m::Markov, sym::Symbol)
    sym === :T || return getfield(m, sym)
    # Immutable struct: build a new instance instead of mutating a copy.
    return Markov(permutedims(getfield(m, :Pi)), getfield(m, :i))
end

Base.:*(m::Markov, X::Array) = multiply_ith_dimension(m.Pi, m.i, X)

"""
    DiscreteChoice(P, i; forward=true)

Law of motion given by the array `P`, applied to state dimension `i` through
`batch_multiply_ith_dimension`.

`P_T` caches a copy of `P` with dimensions `1` and `i + 2` swapped. It is computed
once at construction because both directions are needed.

- `d * X` uses `P` when `forward` is true and `P_T` otherwise.
- `d.T` returns the same object with `forward` flipped, reusing the cached `P_T`.

The `i + 2` offset is kept from the original code; it presumably treats `i` as a
0-based state index placed after the leading dimension of `P`. Confirm this against
`batch_multiply_ith_dimension`.
"""
struct DiscreteChoice <: LawOfMotion
    P::Array
    i::Int
    forward::Bool
    P_T::Array

    function DiscreteChoice(P, i; forward=true)
        # Swap only dims 1 and i + 2 so every other dim stays aligned with `P`.
        # (The old literal `[i+2, 1, 2:i+1, ...]` mixed Ints and ranges and was
        # not a valid permutation.)
        perm = collect(1:ndims(P))
        perm[1], perm[i + 2] = i + 2, 1
        return new(P, i, forward, permutedims(P, perm))
    end

    # Rebuild from precomputed fields; used by `.T` so `P` is not permuted again.
    DiscreteChoice(P, i, forward::Bool, P_T) = new(P, i, forward, P_T)
end

function Base.getproperty(d::DiscreteChoice, sym::Symbol)
    sym === :T || return getfield(d, sym)
    # Immutable struct: build a new instance instead of mutating a copy.
    return DiscreteChoice(
        getfield(d, :P), getfield(d, :i), !getfield(d, :forward), getfield(d, :P_T),
    )
end

function Base.:*(d::DiscreteChoice, X::Array)
    return batch_multiply_ith_dimension(d.forward ? d.P : d.P_T, d.i, X)
end
